#pragma once
#include "../common.hpp"
#include "RTreeFunction.hpp"

namespace zzz{
template <typename T>
/// Abstract interface shared by internal nodes (RTreeNode) and leaves (RTreeLeaf)
/// of a randomized tree over samples of element type T.
struct RTreeNodeBase
{
  /// Tag used to tell internal nodes from leaves without RTTI.
  typedef enum{NODE,LEAF} NODETYPE;
  /// Fixed at construction; NODE for RTreeNode, LEAF for RTreeLeaf.
  const NODETYPE type_;  
  /// @param level   depth of this node in the tree (stored in level_)
  /// @param type    NODE or LEAF
  /// @param funcgen generator that internal nodes call to create their split test;
  ///                it is passed down to every child that is created
  RTreeNodeBase(const int level,NODETYPE type, typename RTreeFunction<T>::RTreeFuncGenerator funcgen)
    :type_(type),level_(level),funcgen_(funcgen){}   // listed in declaration order (avoids -Wreorder)
  virtual ~RTreeNodeBase(){}

  /// Train with one sample (v, tag); leaves additionally remember the sample so a later Grow can split them.
  virtual void GrowTrain(const T *v, const int tag)=0;
  /// Train with one sample (v, tag) by counting its tag in the leaf it reaches.
  virtual void Train(const T *v, const int tag)=0;
  /// Per-tag probability array of the leaf that sample v reaches.
  virtual const double *Classify(const T *v) const =0;
  /// Discard accumulated training data and resize for ntag tags.
  virtual void ClearData(const int ntag)=0;
  /// Recompute leaf probabilities from the accumulated counts.
  virtual void UpdateP()=0;
  /// Debug dump of the subtree; str is the indentation prefix.
  virtual void Print(string str)=0;
  /// Depth of this node (children of an internal node are created at level_+1).
  int level_;
  /// Split-test generator shared by the whole tree.
  typename RTreeFunction<T>::RTreeFuncGenerator funcgen_;
};

// RTreeNode creates RTreeLeaf children, but RTreeLeaf is defined further down;
// this declaration lets compilers with two-phase name lookup (GCC, Clang) find it.
template <typename T,int RTREE_NODE_SPLIT> struct RTreeLeaf;

//Randomized Tree non-leaf node
// Routes every sample to one of RTREE_NODE_SPLIT children using the split test
// Decide. Children start as leaves and may later be grown into nodes themselves.
template <typename T,int RTREE_NODE_SPLIT>
struct RTreeNode : public RTreeNodeBase<T>
{
  /// Creates a node at depth `level` with a fresh split test obtained from
  /// funcgen and RTREE_NODE_SPLIT empty leaves (ntag tags each) one level deeper.
  RTreeNode(const int level,int ntag, typename RTreeFunction<T>::RTreeFuncGenerator funcgen)
    :RTreeNodeBase<T>(level,RTreeNodeBase<T>::NODE,funcgen),ntag_(ntag)
  {
    Decide=this->funcgen_(NULL);
    for (int i=0; i<RTREE_NODE_SPLIT; i++)
      node_[i]=new RTreeLeaf<T,RTREE_NODE_SPLIT>(this->level_+1,ntag_,this->funcgen_);
  }
  /// Owns (and deletes) its children and its split test.
  ~RTreeNode()
  {
    for (int i=0; i<RTREE_NODE_SPLIT; i++)
      delete node_[i];
    delete Decide;
  }

  // The routing operations forward the sample to the child selected by Decide.
  // NOTE(review): assumes (*Decide)(v) always yields an index in [0, RTREE_NODE_SPLIT).
  void GrowTrain(const T *v, const int tag){node_[(*Decide)(v)]->GrowTrain(v,tag);}
  void Train(const T *v, const int tag){node_[(*Decide)(v)]->Train(v,tag);}
  const double *Classify(const T *v) const {return node_[(*Decide)(v)]->Classify(v);}

  //grow: recursively try to split every leaf below this node.
  // Leaves of this node are only split while this node is shallower than
  // MAX_LEVEL-1 (MAX_LEVEL is defined outside this file).
  void Grow()
  {
    const bool canSplitLeaves = this->level_ < MAX_LEVEL-1;
    for (int i=0; i<RTREE_NODE_SPLIT; i++)
    {
      if (node_[i]->type_==RTreeNodeBase<T>::NODE)
      {
        static_cast<RTreeNode*>(node_[i])->Grow();
        continue;
      }
      if (!canSplitLeaves) continue;   // keep visiting siblings instead of aborting the loop
      // RTreeLeaf::Grow already grows the subtree of the node it returns, so the
      // replacement must not be grown a second time here.
      RTreeNode *newnode=static_cast<RTreeLeaf<T,RTREE_NODE_SPLIT>*>(node_[i])->Grow();
      if (newnode==NULL) continue;
      delete node_[i];
      node_[i]=newnode;
    }
  }

  //evaluate the partition: information gain of this split, in bits.
  //   gain = H(S) - sum_i |S_i|/|S| * H(S_i)
  // where S is the union of the samples held by the children and S_i those of child i.
  // Only valid right after construction + GrowTrain, while all children are leaves.
  // Returns 0 for a degenerate split (some child received no sample) and for
  // splits whose gain is only floating-point noise, so callers that keep the
  // candidate with the largest strictly-positive value never pick a useless split.
  double Gain()
  {
    const double LOG2=log(2.0);
    double childEntropy=0;   // sum_i |S_i| * H(S_i), in nats
    int nS=0;
    for (int i=0; i<RTREE_NODE_SPLIT; i++)
    {
      RTreeLeaf<T,RTREE_NODE_SPLIT> *leaf=static_cast<RTreeLeaf<T,RTREE_NODE_SPLIT>*>(node_[i]);
      leaf->UpdateP();
      const int nSi=leaf->allcount_;
      if (nSi==0) return 0;
      nS+=nSi;
      double ESi=0;
      for (int j=0; j<ntag_; j++)
        if (leaf->p_[j]!=0) ESi-=leaf->p_[j]*log(leaf->p_[j]);
      childEntropy+=nSi*ESi;
    }

    // Entropy of the parent set S, from the tag counts summed over the children.
    double parentEntropy=0;
    for (int j=0; j<ntag_; j++)
    {
      int nj=0;
      for (int i=0; i<RTREE_NODE_SPLIT; i++)
        nj+=static_cast<RTreeLeaf<T,RTREE_NODE_SPLIT>*>(node_[i])->count_[j];
      if (nj!=0)
      {
        const double pj=nj/(double)nS;
        parentEntropy-=pj*log(pj);
      }
    }

    const double gain=(parentEntropy-childEntropy/nS)/LOG2;
    return gain>1e-12 ? gain : 0;   // treat rounding residue as "no gain"
  }

  // clear data, after grow tree and before train real data
  void ClearData(const int ntag)
  {
    ntag_=ntag;   // keep in sync with the children's arrays; Gain() iterates ntag_ tags
    for (int i=0; i<RTREE_NODE_SPLIT; i++) node_[i]->ClearData(ntag);
  }

  //update probability, before classify after train
  void UpdateP(){for (int i=0; i<RTREE_NODE_SPLIT; i++) node_[i]->UpdateP();}

  //for debug
  virtual void Print(string str)
  {
    zout<<str;
    zout<<"|-";
    Decide->print();
    zout<<endl;
    for (int i=0; i<RTREE_NODE_SPLIT; i++)
      node_[i]->Print(str+"| ");
  }

  RTreeNodeBase<T> *node_[RTREE_NODE_SPLIT];   // owned children
  int ntag_;                                   // number of tags handled by the children
  RTreeFunction<T> *Decide;                    // owned split test

private:
  // Non-copyable: node_ and Decide are owned raw pointers (declared, never defined).
  RTreeNode(const RTreeNode&);
  RTreeNode& operator=(const RTreeNode&);
};


//Randomized Tree Leaf
// Accumulates per-tag counts of the samples routed to it and converts them into
// the probabilities (p_) returned by Classify. During tree growing it also
// remembers the samples it received so it can be split into an RTreeNode.
template <typename T,int RTREE_NODE_SPLIT>
struct RTreeLeaf : public RTreeNodeBase<T>
{
  /// @param level   depth of this leaf
  /// @param ntag    number of distinct tags (size of count_ and p_)
  /// @param funcgen split-test generator handed to the node this leaf may grow into
  RTreeLeaf(const int level, const int ntag, typename RTreeFunction<T>::RTreeFuncGenerator funcgen)
    :RTreeNodeBase<T>(level,RTreeNodeBase<T>::LEAF,funcgen)
  {
    ntag_=ntag;
    count_=new int[ntag];
    p_=new double[ntag];
    // Start from a well-defined empty state; previously p_, allcount_ and
    // availntag_ stayed uninitialized until the first UpdateP().
    ResetStats();
  }
  ~RTreeLeaf()
  {
    delete[] count_;
    delete[] p_;
  }

  // clear data, after grow tree and before train real data.
  // Drops the remembered grow samples and zeroes counts, probabilities and
  // summary counters, reallocating the arrays when the number of tags changes.
  void ClearData(const int ntag)
  {
    if (ntag!=ntag_)
    {
      delete[] count_;
      count_=NULL;   // keep the destructor safe if an allocation below throws
      delete[] p_;
      p_=NULL;
      count_=new int[ntag];
      p_=new double[ntag];
    }
    ntag_=ntag;
    growtraindata_.clear();
    ResetStats();
  }

  // update probability from the accumulated counts, before classify after train.
  // Also refreshes allcount_ (total samples) and availntag_ (tags seen at least once).
  void UpdateP()
  {
    availntag_=0;
    allcount_=0;
    for (int i=0; i<ntag_; i++) 
      if (count_[i]!=0) 
      {
        availntag_++;
        allcount_+=count_[i];
      }
    for (int i=0; i<ntag_; i++) 
      p_[i]=(count_[i]!=0) ? count_[i]/(double)allcount_ : 0;
  }

  //////////////////////////////////////////////////////////////////////////
  //train while remember the data, for growing.
  // Only the pointer v is stored. NOTE(review): the caller presumably keeps the
  // sample alive until Grow()/ClearData() has run — confirm against callers.
  void GrowTrain(const T *v, const int tag)
  {
    Train(v,tag);
    growtraindata_.push_back(make_pair(v,tag));
  }

  //this leaf grow to a node which hold new leaves.
  // Builds several candidate nodes with random split tests, trains each on the
  // remembered grow samples and keeps the one with the largest strictly positive
  // RTreeNode::Gain(). The winner is grown recursively before being returned.
  // Returns NULL when no candidate qualifies. The caller owns the returned node
  // (RTreeNode::Grow replaces this leaf with it).
  RTreeNode<T,RTREE_NODE_SPLIT> *Grow()
  {
    UpdateP();   // availntag_ must reflect the current counts
    // too few distinct tags to grow (MIN_NELEMENT is defined outside this file)
    if (availntag_<=MIN_NELEMENT) return NULL;

    const int tries = (this->level_==0) ? 10 : 100*this->level_;
    const int ndata = (int)growtraindata_.size();

    RTreeNode<T,RTREE_NODE_SPLIT> *best=NULL;
    double bestGain=0;
    for (int t=0; t<tries; t++)
    {
      RTreeNode<T,RTREE_NODE_SPLIT> *candidate=new RTreeNode<T,RTREE_NODE_SPLIT>(this->level_,ntag_,this->funcgen_);
      for (int j=0; j<ndata; j++) candidate->GrowTrain(growtraindata_[j].first,growtraindata_[j].second);
      const double gain=candidate->Gain();
      if (gain>bestGain)
      {
        bestGain=gain;
        delete best;
        best=candidate;
      }
      else
        delete candidate;
    }
    if (best) best->Grow();
    return best;
  }


  //////////////////////////////////////////////////////////////////////////
  //real train: count one sample of the given tag (v itself is not used).
  // NOTE(review): assumes 0 <= tag < ntag_; there is no bounds check.
  void Train(const T *v, const int tag){count_[tag]++;}
  // Probabilities as of the last UpdateP() (all zero before any training).
  const double *Classify(const T *v) const{return p_;}
  virtual void Print(string str)
  {
    zout<<str;
    zout<<"|-";
    zout<<'['<<allcount_<<']'<<endl;
  }

  int *count_;      // samples seen per tag (owned, ntag_ entries)
  double *p_;       // count_ normalized by allcount_ (owned, ntag_ entries)
  int allcount_;    // total samples, refreshed by UpdateP()
  int ntag_;        // number of tags
  int availntag_;   // tags with a non-zero count, refreshed by UpdateP()
  vector<pair<const T*,int> > growtraindata_;   // remembered (sample, tag) pairs for Grow()

private:
  // Zero the counts, probabilities and summary counters for ntag_ tags.
  void ResetStats()
  {
    memset(count_,0,sizeof(int)*ntag_);
    for (int i=0; i<ntag_; i++) p_[i]=0;
    allcount_=0;
    availntag_=0;
  }

  // Non-copyable: count_ and p_ are owned raw arrays (declared, never defined).
  RTreeLeaf(const RTreeLeaf&);
  RTreeLeaf& operator=(const RTreeLeaf&);
};
}